import os
import re
import json
import torch
import gensim
import openai
import tiktoken
import numpy as np
from openai import OpenAI
from deepdiff import DeepDiff
from collections import Counter
from nltk import pos_tag, word_tokenize
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from torch.nn import functional as F
from nltk.stem import WordNetLemmatizer
from transformers import AutoTokenizer, AutoModel
from sklearn.metrics.pairwise import cosine_similarity

from utils.util import read_txt, read_json, write_json
from utils.union_find import UnionFind
from utils.token_count_decorator import token_count_decorator
from planning.src.protocol import Protocol

# Shared models, loaded once at import time and used by Metrics and by the
# module-level alignment helpers below.
_WORD2VEC_PATH = "dataset/GoogleNews-vectors-negative300.bin.gz"
_SCIBERT_NAME = 'allenai/scibert_scivocab_uncased'

# word2vec vectors queried by word_cos_sim().
word2vec_model = gensim.models.KeyedVectors.load_word2vec_format(_WORD2VEC_PATH, binary=True)
# WordNet lemmatizer used to normalize words (as verbs) before tagging/lookup.
lemmatizer = WordNetLemmatizer()
# SciBERT tokenizer/encoder used to embed entity names for alias matching.
tokenizer = AutoTokenizer.from_pretrained(_SCIBERT_NAME)
model = AutoModel.from_pretrained(_SCIBERT_NAME)

class Metrics:
    '''
    Score a generated ("novel") protocol program against a ground-truth
    protocol along six dimensions; see ``get_metrics``.
    '''

    # Exclusive cosine-similarity window in which two SciBERT entity embeddings
    # become alias candidates that are then confirmed by the LLM.
    # NOTE(review): cosine similarity lies in [-1, 1], so the default window
    # (1.5, 2) can never match and LLM alias merging is effectively disabled.
    # The values are kept to preserve existing scores; override them on the
    # class or an instance to enable alias merging.
    ALIAS_SIM_LOWER = 1.5
    ALIAS_SIM_UPPER = 2

    # OpenAI errors that resending the identical request cannot fix; these are
    # re-raised instead of being retried forever.
    _NON_RETRYABLE_ERRORS = (
        openai.BadRequestError,
        openai.AuthenticationError,
        openai.PermissionDeniedError,
        openai.NotFoundError,
    )

    def __init__(self, domain, novel_protocol: Protocol, groundtruth_protocol: Protocol, novel_program_type: str) -> None:
        '''
        Args:
            domain: key into the ground-truth metadata cache loaded from
                ``planning/data/dataset_metadata.json``.
            novel_protocol: Protocol whose ``program`` is evaluated.
            groundtruth_protocol: reference Protocol; which of its programs is
                compared is chosen by ``novel_program_type``.
            novel_program_type: "dsl", "pseudocode" or "multi-dsl".

        Raises:
            ValueError: if ``novel_program_type`` is not one of the above.
        '''
        if novel_program_type not in ("dsl", "pseudocode", "multi-dsl"):
            raise ValueError("novel_program_type must be 'dsl' or 'pseudocode' or 'multi-dsl'")
        self.domain = domain
        self.novel_protocol = novel_protocol
        self.groundtruth_protocol = groundtruth_protocol
        self.novel_program_type = novel_program_type
        # Filled by get_dimension_1 and reused by get_dimension_4.
        self.operations_sequence_novel = []
        self.operations_sequence_groundtruth = []
        self.program_components_extraction_prompt = read_txt("planning/data/prompt/program_components_extraction.txt")
        self.flowunit_extraction_prompt = read_txt("planning/data/prompt/flowunit_extraction.txt")
        self.same_components_prompt = read_txt("dsl_design/data/prompt/same_component_judgement.txt")
        self.program_devices_extraction_prompt = read_txt("planning/data/prompt/program_devices_extraction.txt")
        self.device_extraction_prompt = read_txt("planning/data/prompt/device_extraction.txt")
        self.same_devices_prompt = read_txt("dsl_design/data/prompt/same_device_judgement.txt")
        self.final_state_description_prompt = read_txt("planning/data/prompt/final_state_description.txt")
        # Union-find holding the alias groups of the most recent merge.
        self.uf = UnionFind()
        self.groundtruth_metadata_path = "planning/data/dataset_metadata.json"
        self.groundtruth_metadata = read_json(self.groundtruth_metadata_path)

    def get_metrics(self) -> dict:
        '''
        Compute all six dimensions.

        Dimension 1 is evaluated first; it fills the operation sequences that
        dimension 4 reuses.
        '''
        return {
            "dimension_1": self.get_dimension_1(),
            "dimension_2": self.get_dimension_2(),
            "dimension_3": self.get_dimension_3(),
            "dimension_4": self.get_dimension_4(),
            "dimension_5": self.get_dimension_5(),
            "dimension_6": self.get_dimension_6(),
        }

    def get_corresponding_program(self):
        '''
        Return the ground-truth program that matches ``novel_program_type``:
        ``program`` for "pseudocode", ``dsl_program`` for "dsl" and
        ``multi_dsl_program`` for "multi-dsl".
        '''
        if self.novel_program_type == "pseudocode":
            return self.groundtruth_protocol.program
        elif self.novel_program_type == "dsl":
            return self.groundtruth_protocol.dsl_program
        elif self.novel_program_type == "multi-dsl":
            return self.groundtruth_protocol.multi_dsl_program

    def get_dimension_1(self) -> float:
        '''
        IoU (multiset) on operations, i.e. the first verb of every step.
        '''
        self.__compute_operations_sequences()
        return self.__iou(self.operations_sequence_novel, self.operations_sequence_groundtruth)

    def get_dimension_2(self) -> float:
        '''
        IoU on reagents and intermediate products (components) after merging
        aliases between the two sides. Uses GPT.
        '''
        novel_reagents = self.get_components(program=self.novel_protocol.program, program_type=self.novel_program_type)
        groundtruth_reagents = self.__groundtruth_entities("flowunits", self.get_components)
        return self.__entity_iou(novel_reagents, groundtruth_reagents, entity_type="component")

    def get_dimension_3(self) -> float:
        '''
        IoU on device types after merging aliases between the two sides.
        Uses GPT.
        '''
        novel_devices = self.get_devices(program=self.novel_protocol.program, program_type=self.novel_program_type)
        groundtruth_devices = self.__groundtruth_entities("devices", self.get_devices)
        return self.__entity_iou(novel_devices, groundtruth_devices, entity_type="device")

    def get_dimension_4(self) -> float:
        '''
        Alignment similarity between the execution (operation) sequences, via
        ``seqAlign``.

        Reuses the sequences computed by ``get_dimension_1`` and computes them
        first if they are still empty. Returns 0.0 when either sequence is
        empty.
        '''
        if not self.operations_sequence_novel and not self.operations_sequence_groundtruth:
            self.__compute_operations_sequences()
        if not self.operations_sequence_novel or not self.operations_sequence_groundtruth:
            # seqAlign scores 0 for one empty side and divides by zero for two.
            return 0.0
        return seqAlign(self.operations_sequence_novel, self.operations_sequence_groundtruth)

    def get_dimension_5(self) -> float:
        '''
        Similarity between final states (output of the experiment, namely the
        description of the last operation and the final product).

        GPT describes the final state of each program. The two descriptions are
        compared by cosine similarity of their OpenAI embeddings.
        '''
        novel_prompt = self.final_state_description_prompt.replace("---PSEUDOCODE---", self.__to_json(self.novel_protocol.program))
        novel_final_description = self.__chatgpt_function(novel_prompt)
        groundtruth_prompt = self.final_state_description_prompt.replace("---PSEUDOCODE---", self.__to_json(self.get_corresponding_program()))
        groundtruth_final_description = self.__chatgpt_function(groundtruth_prompt)

        novel_embeddings = self.__get_openai_embedding(novel_final_description)
        groundtruth_embeddings = self.__get_openai_embedding(groundtruth_final_description)
        return self.__vector_cos_sim(novel_embeddings, groundtruth_embeddings)

    def get_dimension_6(self) -> float:
        '''
        Parameter-level similarity: cosine similarity between the OpenAI
        embeddings of the two whole programs serialized as JSON.

        (``compare_jsons_with_bleu`` is an alternative BLEU-based measure; it
        is not used here.)
        '''
        novel_embeddings = self.__get_openai_embedding(self.__to_json(self.novel_protocol.program))
        groundtruth_embeddings = self.__get_openai_embedding(self.__to_json(self.get_corresponding_program()))
        return self.__vector_cos_sim(novel_embeddings, groundtruth_embeddings)

    def get_operations_sequence(self, program, program_type: str) -> list:
        '''
        Args:
            program: for "dsl"/"multi-dsl", an iterable of step dicts, where
                steps with an "Operation" key contribute. For "pseudocode", a
                dict whose keys (function names) are the operations.
            program_type: "dsl", "pseudocode" or "multi-dsl".
        Returns:
            list of the first verb of each operation, in program order.
            Operations with no usable token are skipped. Returns [] for an
            unknown ``program_type``.
        '''
        if program_type in ("dsl", "multi-dsl"):
            operation_names = [step_dic["Operation"] for step_dic in program if "Operation" in step_dic]
        elif program_type == "pseudocode":
            operation_names = list(program.keys())
        else:
            return []
        first_verbs = (self.__get_first_verb(name) for name in operation_names)
        return [verb for verb in first_verbs if verb]

    def get_components(self, program, program_type: str) -> list:
        '''
        Extract component (reagent / intermediate product) names.

        Args:
            program: a DSL step list ("dsl"/"multi-dsl"), a pseudocode program
                ("pseudocode"), or ground-truth protocol text ("groundtruth").
            program_type: one of the values above.
        Returns:
            list of component names (may contain duplicates). Returns [] when
            nothing is found or when ``program_type`` is unknown.
        '''
        if program_type in ("dsl", "multi-dsl"):
            components = []
            # If any step carries a FlowUnit, components come only from
            # FlowUnit steps; otherwise from each step's Precond.SlotArg.
            multi = any("FlowUnit" in step for step in program)
            for step in program:
                if multi:
                    if "FlowUnit" in step:
                        components.append(step["FlowUnit"]["Component"])
                    continue
                try:
                    components.extend(step["Precond"]["SlotArg"])
                except (KeyError, TypeError):
                    # Step has no (well-formed) Precond.SlotArg; skip it.
                    continue
            return components

        if program_type == "pseudocode":
            prompt = self.program_components_extraction_prompt.replace("---PSEUDOCODE---", self.__to_json(program))
            response = self.__chatgpt_function(prompt)
            flowunits = [flowunit.strip() for flowunit in response.split(",") if flowunit.strip()]
            # The prompt's "nothing found" answer contains NONE.
            if "NONE" in flowunits or not flowunits:
                return []
            return flowunits

        if program_type == "groundtruth":
            components = []
            for sentence in self.__convert_to_sentence_list(program):
                flowunits = self.__flowunit_extraction(sentence)
                if "NONE" not in flowunits:
                    components.extend(flowunits)
            return components

        return []

    def get_devices(self, program, program_type: str) -> list:
        '''
        Extract device type names.

        Args:
            program: a DSL step list ("dsl"/"multi-dsl"), a pseudocode program
                ("pseudocode"), or ground-truth protocol text ("groundtruth").
            program_type: one of the values above.
        Returns:
            list of device types (may contain duplicates). Returns [] when
            nothing is found or when ``program_type`` is unknown.
        '''
        if program_type in ("dsl", "multi-dsl"):
            devices = []
            for step in program:
                if "Execution" not in step:
                    continue
                execution = step["Execution"]
                # "Execution" is either one device dict or a list of them.
                if isinstance(execution, dict):
                    entries = [execution]
                elif isinstance(execution, list):
                    entries = execution
                else:
                    entries = []
                for entry in entries:
                    device_type = entry.get("DeviceType") if isinstance(entry, dict) else None
                    if isinstance(device_type, str):
                        devices.append(device_type)
            return devices

        if program_type == "pseudocode":
            prompt = self.program_devices_extraction_prompt.replace("---PSEUDOCODE---", self.__to_json(program))
            response = self.__chatgpt_function(prompt)
            devices = [device.strip() for device in response.split(",") if device.strip()]
            if "NONE" in devices or not devices:
                return []
            return devices

        if program_type == "groundtruth":
            devices = []
            for sentence in self.__convert_to_sentence_list(program):
                device_list = self.__device_extraction(sentence)
                if "NONE" not in device_list:
                    devices.extend(device_list)
            return devices

        return []

    def compare_jsons_with_bleu(self, json1, json2):
        '''
        BLEU score for every value that changed between two JSON-like objects.

        ``json1`` values are the references and ``json2`` values the
        candidates. List order is ignored when diffing.

        Returns:
            list of ``(deepdiff_path, bleu_score)`` tuples.
        '''
        differences = DeepDiff(json1, json2, ignore_order=True)
        smoothing = SmoothingFunction().method1
        bleu_scores = []
        for path, change in differences.get('values_changed', {}).items():
            reference = word_tokenize(str(change['old_value']))
            candidate = word_tokenize(str(change['new_value']))
            bleu_scores.append((path, sentence_bleu([reference], candidate, smoothing_function=smoothing)))
        return bleu_scores

    def __compute_operations_sequences(self):
        '''Fill both operation sequences from the novel and ground-truth programs.'''
        self.operations_sequence_novel = self.get_operations_sequence(self.novel_protocol.program, self.novel_program_type)
        self.operations_sequence_groundtruth = self.get_operations_sequence(self.get_corresponding_program(), self.novel_program_type)

    def __groundtruth_entities(self, cache_key: str, extractor) -> list:
        '''
        Ground-truth entities for dimension 2 ("flowunits") or 3 ("devices").

        For pseudocode, results are cached in memory in
        ``self.groundtruth_metadata[domain][protocol id][cache_key]``. They are
        extracted via GPT only when the cached value is missing or empty. This
        method does not write the cache back to disk.
        For DSL types, entities are read from the matching DSL program.

        Args:
            cache_key: metadata key to read and write.
            extractor: ``get_components`` or ``get_devices``.
        '''
        if self.novel_program_type != "pseudocode":
            return extractor(program=self.get_corresponding_program(), program_type="dsl")
        # .setdefault avoids crashing when the domain or protocol is not cached yet.
        protocol_cache = self.groundtruth_metadata.setdefault(self.domain, {}).setdefault(self.groundtruth_protocol.id, {})
        if not protocol_cache.get(cache_key):
            protocol_cache[cache_key] = extractor(program=self.groundtruth_protocol.program, program_type="pseudocode")
        return protocol_cache[cache_key]

    def __entity_iou(self, novel_entities, groundtruth_entities, entity_type: str) -> float:
        '''
        Multiset IoU of two entity lists after mapping aliases to one
        canonical name. Returns 0.0 if either list is empty.
        '''
        if not novel_entities or not groundtruth_entities:
            return 0.0
        # dict.fromkeys de-duplicates in first-seen order, so candidate order,
        # GPT calls and canonical names do not depend on the hash seed.
        merge_result = self.__alias_judgement(
            list(dict.fromkeys(novel_entities)),
            list(dict.fromkeys(groundtruth_entities)),
            entity_type=entity_type,
        )
        canonical = {original: unified for unified, originals in merge_result.items() for original in originals}
        return self.__iou(
            [canonical.get(entity, entity) for entity in novel_entities],
            [canonical.get(entity, entity) for entity in groundtruth_entities],
        )

    @staticmethod
    def __to_json(program) -> str:
        '''Serialize a program for prompts and embeddings.'''
        return json.dumps(program, indent=4, ensure_ascii=False)

    def __get_first_verb(self, operation_str):
        '''
        First verb of an operation name.

        Splits on "_" and " " (dropping empty pieces), then lower-cases and
        lemmatizes each token as a verb. Returns the first token that NLTK tags
        as a verb. Falls back to the first token, or "" if there are no tokens.
        '''
        tokens = [token for token in re.split(r'[_ ]', operation_str) if token]
        if not tokens:
            return ""
        lemmatized_tokens = [lemmatizer.lemmatize(token.lower(), pos="v") for token in tokens]
        for word, pos in pos_tag(lemmatized_tokens):
            # Penn Treebank verb tags: VB, VBD, VBG, VBN, VBP, VBZ.
            if pos.startswith('VB'):
                return word
        return lemmatized_tokens[0]

    def __convert_to_sentence_list(self, steps):
        '''
        Split protocol text into stripped, non-empty lines and keep the
        numbered ones (lines starting with digits followed by ".").
        '''
        if not steps:
            return []
        sentences = [sentence.strip() for sentence in steps.split("\n") if sentence.strip()]
        return [sentence for sentence in sentences if re.match(r'^\d+\.', sentence)]

    def __flowunit_extraction(self, sentence):
        '''
        Ask GPT for the components in ``sentence``. Returns ["NONE"] if the
        reply contains "NONE", otherwise the comma-separated names.
        '''
        prompt = self.flowunit_extraction_prompt.replace("---SENTENCES---", sentence)
        response = self.__chatgpt_function(prompt).strip()
        if "NONE" in response:
            return ["NONE"]
        return [flowunit.strip() for flowunit in response.split(",") if flowunit.strip()]

    def __device_extraction(self, sentence):
        '''
        Ask GPT for the devices in ``sentence``. Returns ["NONE"] if the reply
        contains "NONE", otherwise the comma-separated names.
        '''
        prompt = self.device_extraction_prompt.replace("---SENTENCES---", sentence)
        response = self.__chatgpt_function(prompt).strip()
        if "NONE" in response:
            return ["NONE"]
        return [device.strip() for device in response.split(",") if device.strip()]

    def __iou(self, seq1, seq2):
        '''
        Multiset IoU of two sequences: |A ∩ B| / |A ∪ B| using element counts.
        Returns 0.0 when both are empty.
        '''
        counter1 = Counter(seq1)
        counter2 = Counter(seq2)
        intersection = sum((counter1 & counter2).values())
        union = sum((counter1 | counter2).values())
        return intersection / union if union != 0 else 0.0

    def __alias_judgement(self, match_1: list, match_2: list, entity_type: str) -> dict:
        '''
        Find entities in ``match_1`` and ``match_2`` that name the same thing.

        A pair is a candidate when the cosine similarity of its SciBERT
        embeddings lies strictly inside (ALIAS_SIM_LOWER, ALIAS_SIM_UPPER).
        GPT then confirms each candidate with "Yes"/"No".

        Args:
            match_1, match_2: non-empty lists of unique entity names.
            entity_type: "component" or "device" (selects the judgement prompt).
        Returns:
            dict mapping a canonical name to every name merged into it.
        Raises:
            ValueError: for an unknown ``entity_type``.
            RuntimeError: if GPT does not answer "Yes"/"No" within 5 attempts.
        '''
        prompt_templates = {"component": self.same_components_prompt, "device": self.same_devices_prompt}
        if entity_type not in prompt_templates:
            raise ValueError(f"entity_type must be 'component' or 'device', got {entity_type!r}")
        template = prompt_templates[entity_type]

        match_1_emb = np.array([self.__get_scibert_embedding(entity) for entity in match_1])
        match_2_emb = np.array([self.__get_scibert_embedding(entity) for entity in match_2])
        cos_matrix = cosine_similarity(match_1_emb, match_2_emb)
        homo_candidates = [
            (entity_1, entity_2)
            for i, entity_1 in enumerate(match_1)
            for j, entity_2 in enumerate(match_2)
            if self.ALIAS_SIM_LOWER < cos_matrix[i, j] < self.ALIAS_SIM_UPPER
        ]

        same_pairs = []
        for homo_pair in homo_candidates:
            prompt = template.replace("---TARGET---", str(homo_pair))
            for _ in range(5):
                response = self.__chatgpt_function(content=prompt)
                answer = response.strip()
                if answer in ("Yes", "No"):
                    break
            else:
                raise RuntimeError(f"Failed to fetch a valid response after 5 attempts. Last response was: {response}")
            if answer == "Yes":
                same_pairs.append(homo_pair)

        return self.__merge_same(same_pairs)

    def __merge_same(self, same_pairs):
        '''
        Group entities connected by ``same_pairs`` (transitively).

        A fresh union-find is used per call, so component and device merges do
        not leak into each other. The key of each group is its first member in
        insertion order.
        '''
        self.uf = UnionFind()
        for entity1, entity2 in same_pairs:
            self.uf.add(entity1)
            self.uf.add(entity2)
            self.uf.union(entity1, entity2)

        merged_entities = {}
        for entity in self.uf.parent:
            merged_entities.setdefault(self.uf.find(entity), []).append(entity)
        return {entities[0]: entities for entities in merged_entities.values()}

    def __call_with_retry(self, request, initial_delay=1.0, max_delay=60.0):
        '''
        Call ``request()`` until it succeeds.

        ``openai.APIError`` (e.g. rate limit, connection or server errors) is
        printed and retried with exponential backoff capped at ``max_delay``
        seconds. Errors in ``_NON_RETRYABLE_ERRORS`` are re-raised immediately,
        because resending the same request cannot succeed.
        '''
        import time

        delay = initial_delay
        while True:
            try:
                return request()
            except self._NON_RETRYABLE_ERRORS:
                raise
            except openai.APIError as error:
                print(error)
                time.sleep(delay)
                delay = min(delay * 2, max_delay)

    @token_count_decorator(flow="together", batch=False)
    def __chatgpt_function(self, content, gpt_model="gpt-4o-mini"):
        '''Send ``content`` as one user message and return the reply text.'''
        client = OpenAI(
            api_key=os.environ.get("OPENAI_API_KEY"),
        )
        chat_completion = self.__call_with_retry(
            lambda: client.chat.completions.create(
                messages=[
                    {"role": "user", "content": content}
                ],
                model=gpt_model,
            )
        )
        return chat_completion.choices[0].message.content

    @token_count_decorator(model="text-embedding-3-large", batch=False)
    def __get_openai_embedding(self, text, model="text-embedding-3-large", max_tokens=8191):
        '''
        OpenAI embedding of ``text``.

        The text is first truncated to ``max_tokens`` tokens of the model's
        tiktoken encoding, and newlines are replaced by spaces.
        '''
        encoding = tiktoken.encoding_for_model(model)
        tokens = encoding.encode(text)
        if len(tokens) > max_tokens:
            text = encoding.decode(tokens[:max_tokens])
        text = text.replace("\n", " ")

        client = openai.OpenAI(
            api_key=os.environ.get("OPENAI_API_KEY"),
        )
        response = self.__call_with_retry(lambda: client.embeddings.create(input=[text], model=model))
        return response.data[0].embedding

    def __get_scibert_embedding(self, text):
        '''
        L2-normalized SciBERT vector for ``text``: the last hidden state at
        token position 0 (the [CLS] token for BERT-style tokenizers).
        '''
        inputs = tokenizer(text, return_tensors='pt')
        with torch.no_grad():
            outputs = model(**inputs)
        embedding = outputs.last_hidden_state[:, 0, :]
        embedding = F.normalize(embedding, p=2, dim=1)
        return embedding.squeeze().numpy()

    def __vector_cos_sim(self, vec1, vec2):
        '''
        Cosine similarity of two vectors as a Python float.

        Returns 0.0 if either vector has zero norm, or if the inputs are not
        compatible numeric vectors.
        '''
        try:
            denominator = np.linalg.norm(vec1) * np.linalg.norm(vec2)
            if denominator == 0:
                return 0.0
            return float(np.dot(vec1, vec2) / denominator)
        except (TypeError, ValueError):
            return 0.0

    def __dump_groundtruth_metadata(self):
        '''
        Write the (possibly updated) ground-truth metadata cache back to
        ``groundtruth_metadata_path``.

        NOTE(review): not called anywhere in this file.
        '''
        write_json(self.groundtruth_metadata_path, self.groundtruth_metadata)


# Similarity of two word sequences, nominally in [0, 1].
def seqAlign(tra1, tra2):
    '''
    Alignment-based similarity of two word sequences.

    Equal-length sequences are aligned globally (``__needleman_wunsch``).
    Otherwise the best local alignment (``__smith_waterman``) is used. The
    score is normalized as ``2 * score / (len(tra1) + len(tra2))``.

    Nominally in [0, 1]. The score is not clamped, so values outside that
    range are possible (e.g. with negative word similarities).

    Returns 0.0 when both sequences are empty, instead of dividing by zero.
    '''
    total_length = len(tra1) + len(tra2)
    if total_length == 0:
        return 0.0
    if len(tra1) == len(tra2):
        max_score = __needleman_wunsch(tra1, tra2)
    else:
        max_score = __smith_waterman(tra1, tra2)
    return 2 * max_score / total_length

def __needleman_wunsch(seq1, seq2, match_score=1, similarity=None):
    '''
    Global (Needleman-Wunsch style) alignment score of two word sequences.

    For cell (i, j):
      * match: ``+match_score`` if the words are equal, otherwise
        ``+sim(seq1[i-1], seq2[j-1])``;
      * delete: ``-sim(seq1[i-1], seq2[j])``, or -1 in the last column;
      * insert: ``-sim(seq1[i], seq2[j-1])``, or -1 in the last row.
    The first column/row hold cumulative negative similarities against
    ``seq2[0]`` / ``seq1[0]``.

    Args:
        seq1, seq2: sequences of words. Raises IndexError if exactly one of
            them is empty (seqAlign only calls this for equal lengths).
        match_score: score for an exact word match.
        similarity: optional ``callable(word_a, word_b) -> float``; defaults to
            ``word_cos_sim``.

    Returns:
        The bottom-right cell of the score matrix (0 for two empty sequences).
    '''
    sim = word_cos_sim if similarity is None else similarity
    # Each pair's similarity is needed up to three times per cell; compute the
    # full matrix once instead of repeatedly querying the word model.
    sim_matrix = [[sim(word_a, word_b) for word_b in seq2] for word_a in seq1]

    rows, cols = len(seq1) + 1, len(seq2) + 1
    score_matrix = [[0] * cols for _ in range(rows)]
    # Initialize the first column and the first row.
    for i in range(1, rows):
        score_matrix[i][0] = score_matrix[i - 1][0] - sim_matrix[i - 1][0]
    for j in range(1, cols):
        score_matrix[0][j] = score_matrix[0][j - 1] - sim_matrix[0][j - 1]
    # Fill the rest of the matrix.
    for i in range(1, rows):
        for j in range(1, cols):
            # Added (not subtracted) for non-identical words.
            mismatch_bonus = sim_matrix[i - 1][j - 1]
            delete_penalty = -sim_matrix[i - 1][j] if j != cols - 1 else -1
            insert_penalty = -sim_matrix[i][j - 1] if i != rows - 1 else -1
            match = score_matrix[i - 1][j - 1] + (match_score if seq1[i - 1] == seq2[j - 1] else mismatch_bonus)
            delete = score_matrix[i - 1][j] + delete_penalty
            insert = score_matrix[i][j - 1] + insert_penalty
            score_matrix[i][j] = max(match, delete, insert)
    return score_matrix[rows - 1][cols - 1]

def __smith_waterman(seq1, seq2, match_score=1, similarity=None):
    '''
    Local (Smith-Waterman style) alignment score of two word sequences.

    Cell scoring matches ``__needleman_wunsch``. The difference is that the
    delete and insert moves are floored at 0, and the best cell anywhere in
    the matrix is returned.

    Args:
        seq1, seq2: sequences of words (either may be empty).
        match_score: score for an exact word match.
        similarity: optional ``callable(word_a, word_b) -> float``; defaults to
            ``word_cos_sim``.

    Returns:
        The maximum cell score (0 if no cell exceeds 0).
    '''
    sim = word_cos_sim if similarity is None else similarity
    # Precompute every pair's similarity once; each is reused up to 3 times.
    sim_matrix = [[sim(word_a, word_b) for word_b in seq2] for word_a in seq1]

    rows, cols = len(seq1) + 1, len(seq2) + 1
    score_matrix = [[0] * cols for _ in range(rows)]
    max_score = 0
    for i in range(1, rows):
        for j in range(1, cols):
            # Added (not subtracted) for non-identical words.
            mismatch_bonus = sim_matrix[i - 1][j - 1]
            delete_penalty = -sim_matrix[i - 1][j] if j != cols - 1 else -1
            insert_penalty = -sim_matrix[i][j - 1] if i != rows - 1 else -1
            match = score_matrix[i - 1][j - 1] + (match_score if seq1[i - 1] == seq2[j - 1] else mismatch_bonus)
            delete = max(score_matrix[i - 1][j] + delete_penalty, 0)
            insert = max(score_matrix[i][j - 1] + insert_penalty, 0)
            score_matrix[i][j] = max(match, delete, insert)
            if score_matrix[i][j] > max_score:
                max_score = score_matrix[i][j]
    return max_score

def word_cos_sim(word_a, word_b):
    '''
    word2vec cosine similarity of two words after lower-casing them and
    lemmatizing them as verbs.

    Returns:
        A Python float; 0.0 when either lemma is not in the word2vec vocabulary.
    '''
    a = lemmatizer.lemmatize(word_a.lower(), pos="v")
    b = lemmatizer.lemmatize(word_b.lower(), pos="v")
    try:
        # float(): gensim returns a numpy scalar, which json cannot serialize.
        return float(word2vec_model.similarity(a, b))
    except KeyError:
        # gensim KeyedVectors raise KeyError for out-of-vocabulary words.
        return 0.0
